{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## 文本分类\n",
    "\n",
    "模型 预测\n",
    "\n",
    "常用模型\n",
    "    DAN : 情感分类\n",
    "    双向RNN:\n",
    "    堆叠神经网络\n",
    "    CNN:"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "使用PyTorch模型和TorchText做情感分析\n",
    "检测一段文字的情感是证明还是负面的。数据集：电影评论IMDb数据集\n",
    "\n",
    "模型从简单到复杂：\n",
    "   * Word Averaging 模型\n",
    "   * RNN / LSTM 模型\n",
    "   * CNN 模型"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "### 准备数据\n",
    "* TorchText中的Field，决定数据怎么被处理\n",
    "* 使用Text Field来处理定义评论，使用Label field处理2中情感类型pos， neg\n",
    "* Text Field中tokenize=‘spacy' 表示我们用SpaCy tokenizer来tokenize英文语句。如果不使用该参数，默认使用空格来分词\n",
    "* 安装SpaCy   `pip install spacy & python -m spacy donwload en\n",
    "* label 有LabelField定义"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch: 1.9.0+cpu\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "from torchtext.datasets import AG_NEWS\n",
    "from torchtext.data.utils import get_tokenizer\n",
    "from torchtext import vocab\n",
    "\n",
    "SEED = 1\n",
    "torch.manual_seed(SEED)\n",
    "torch.cuda.manual_seed(SEED)\n",
    "torch.backends.cudnn.deterministic = True\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "print('torch:', torch.__version__)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " len(train_iter): 25000, test: 25000\n"
     ]
    }
   ],
   "source": [
    "from torchtext.datasets import IMDB\n",
    "\n",
    "train_iter, test_iter = IMDB()\n",
    "print(f' len(train_iter): {len(train_iter)}, test: {len(test_iter)}')"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "VOCAB_SIZE: 100683\n"
     ]
    }
   ],
   "source": [
    "# 使用 build_vocab_from_iterator 预处理\n",
    "tokenizer = get_tokenizer('basic_english')\n",
    "from torchtext.vocab import build_vocab_from_iterator\n",
    "train_iter = IMDB(split='train')\n",
    "def yield_tokens(data_iter):\n",
    "    for _, text in data_iter:\n",
    "        yield tokenizer(text)\n",
    "\n",
    "vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=[\"<unk>\"])\n",
    "vocab.set_default_index(vocab[\"<unk>\"])\n",
    "VOCAB_SIZE = len(vocab)\n",
    "print(f'VOCAB_SIZE: {VOCAB_SIZE}')"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "outputs": [],
   "source": [
    "# Prepare data processing pipelines\n",
    "text_pipeline = lambda x: vocab(tokenizer(x))\n",
    "label_pipeline = lambda x: 0 if x == 'neg' else 1\n",
    "\n",
    "def collate_batch(batch):\n",
    "    label_list, text_list, offsets = [], [], [0]\n",
    "    for (_label, _text) in batch:\n",
    "         label_list.append(label_pipeline(_label))\n",
    "         processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)\n",
    "         text_list.append(processed_text)\n",
    "         offsets.append(processed_text.size(0))\n",
    "    label_list = torch.tensor(label_list, dtype=torch.int64)\n",
    "    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)\n",
    "    text_list = torch.cat(text_list)\n",
    "    return label_list.to(device), text_list.to(device), offsets.to(device)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "outputs": [],
   "source": [
    "#  划分训练，验证集，\n",
    "from torch.utils.data.dataset import random_split\n",
    "from torchtext.data.functional import to_map_style_dataset\n",
    "# Hyperparameters\n",
    "EPOCHS = 10 # epoch\n",
    "LR = 5  # learning rate\n",
    "BATCH_SIZE = 64 # batch size for training\n",
    "\n",
    "total_accu = None\n",
    "train_iter, test_iter = IMDB()\n",
    "\n",
    "train_dataset = to_map_style_dataset(train_iter)\n",
    "test_dataset = to_map_style_dataset(test_iter)\n",
    "num_train = int(len(train_dataset) * 0.8)\n",
    "split_train_, split_valid_ = random_split(train_dataset, [num_train, len(train_dataset) - num_train])\n",
    "\n",
    "train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE,shuffle=True, collate_fn=collate_batch)\n",
    "valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE,shuffle=True, collate_fn=collate_batch)\n",
    "test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE,shuffle=True, collate_fn=collate_batch)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(tensor([1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 1,\n",
      "        0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1,\n",
      "        0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1]), tensor([ 12, 298,   8,  ...,   2,   2,   2]), tensor([    0,   246,   702,  1650,  1688,  1998,  2234,  2513,  2869,  3020,\n",
      "         3199,  3312,  3737,  4311,  4635,  4779,  5108,  5263,  5570,  5741,\n",
      "         5888,  5961,  6103,  6288,  6722,  6998,  7096,  7282,  7508,  7613,\n",
      "         7771,  8050,  8398,  8534,  8654,  9122,  9266,  9400,  9715,  9798,\n",
      "        10072, 10362, 10513, 10639, 10811, 11564, 11779, 11940, 12156, 12547,\n",
      "        12723, 13430, 13562, 13695, 13859, 14018, 14277, 14892, 15124, 15326,\n",
      "        15570, 15880, 16031, 16163]))\n"
     ]
    }
   ],
   "source": [
    "print(next(train_dataloader._get_iterator()))"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "###  Word Averaging模型\n",
    "\n",
    "1. 定义模型:  nn.EmbeddingBag layer plus a linear layer for the classification purpose.\n",
    "2. 关于实现：avg_pool2d做average pooling。（sent_len, embedding_dim) -> (1,embedding_dim)\n",
    "3. note:avg_pool2d的kernel_size是（embedded.shape[1],1），所以句子长度这个维度会被压扁"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "outputs": [],
   "source": [
    "from torch import nn\n",
    "import torch.nn.functional as F\n",
    "#模型构建\n",
    "class WordAVGModel(nn.Module):\n",
    "    def __init__(self, vocab_size, embedding_dim, output_dim):\n",
    "        super(WordAVGModel,self).__init__()\n",
    "        # TODO 此处有问题\n",
    "        self.embedding = nn.Embedding(vocab_size, embedding_dim) # 单词变成词向量\n",
    "        self.fc = nn.Linear(embedding_dim, output_dim)\n",
    "\n",
    "    def forward(self, text):\n",
    "        embedded = self.embedding(text) # [sent len, batch size, emb dim]\n",
    "        print('shape of embedded', embedded.shape)\n",
    "        # TODO 此处有问题\n",
    "        embedded = embedded.permute(1, 0, 2) # [batch size, sent len, emb dim]\n",
    "        # 逐列求平均\n",
    "        pooled = F.avg_pool2d(embedded, (embedded.shape[1], 1)).squeeze(1) # [batch size, embedding_dim]\n",
    "        return self.fc(pooled)\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "source": [
    "EMBEDDING_SIZE = 64\n",
    "OUTPUT_DIM = 1\n",
    "model = WordAVGModel(VOCAB_SIZE, EMBEDDING_SIZE, 1).to(device)\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "execution_count": 23,
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "outputs": [
    {
     "ename": "AttributeError",
     "evalue": "'Vocab' object has no attribute 'vectors'",
     "output_type": "error",
     "traceback": [
      "\u001B[1;31m---------------------------------------------------------------------------\u001B[0m",
      "\u001B[1;31mAttributeError\u001B[0m                            Traceback (most recent call last)",
      "\u001B[1;32m~\\AppData\\Local\\Temp/ipykernel_11332/973745355.py\u001B[0m in \u001B[0;36m<module>\u001B[1;34m\u001B[0m\n\u001B[1;32m----> 1\u001B[1;33m \u001B[0mprint\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mvocab\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mvectors\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m      2\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n",
      "\u001B[1;32mc:\\users\\dengxixi\\miniconda3\\envs\\torch-ocr\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001B[0m in \u001B[0;36m__getattr__\u001B[1;34m(self, name)\u001B[0m\n\u001B[0;32m   1128\u001B[0m             \u001B[1;32mif\u001B[0m \u001B[0mname\u001B[0m \u001B[1;32min\u001B[0m \u001B[0mmodules\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m   1129\u001B[0m                 \u001B[1;32mreturn\u001B[0m \u001B[0mmodules\u001B[0m\u001B[1;33m[\u001B[0m\u001B[0mname\u001B[0m\u001B[1;33m]\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m-> 1130\u001B[1;33m         raise AttributeError(\"'{}' object has no attribute '{}'\".format(\n\u001B[0m\u001B[0;32m   1131\u001B[0m             type(self).__name__, name))\n\u001B[0;32m   1132\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n",
      "\u001B[1;31mAttributeError\u001B[0m: 'Vocab' object has no attribute 'vectors'"
     ]
    }
   ],
   "source": [
    "print(vocab.vectors)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "import torchtext\n",
    "vector = torchtext.vocab.pretrained_aliases['glove.6B.100d']\n",
    "model.embedding.weight.data.copy_(vector())"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "outputs": [],
   "source": [
    "UNK_IDX = vocab.get_stoi()[\"<unk>\"]\n",
    "model.embedding.weight.data[UNK_IDX] = torch.zeros(EMBEDDING_SIZE)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "OUTPUT_DIM:1, VOCAB_SIZE:100683, EMBEDDING_SIZE:64\n"
     ]
    }
   ],
   "source": [
    "optimizer = torch.optim.Adam(model.parameters())\n",
    "criterion = nn.BCEWithLogitsLoss()\n",
    "print(f'OUTPUT_DIM:{OUTPUT_DIM}, VOCAB_SIZE:{VOCAB_SIZE}, EMBEDDING_SIZE:{EMBEDDING_SIZE}')\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "def train(dataloader):\n",
    "    model.train()\n",
    "    total_acc, total_count = 0, 0\n",
    "    log_interval = 500\n",
    "    start_time = time.time()\n",
    "\n",
    "    for idx, (label, text, offsets) in enumerate(dataloader):\n",
    "        optimizer.zero_grad()\n",
    "        predicted_label = model(text)\n",
    "        loss = criterion(predicted_label, label)\n",
    "        loss.backward()\n",
    "        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)\n",
    "        optimizer.step()\n",
    "        total_acc += (predicted_label.argmax(1) == label).sum().item()\n",
    "        total_count += label.size(0)\n",
    "        if idx % log_interval == 0 and idx > 0:\n",
    "            elapsed = time.time() - start_time\n",
    "            print('| epoch {:3d} | {:5d}/{:5d} batches '\n",
    "                  '| accuracy {:8.3f}'.format(epoch, idx, len(dataloader),\n",
    "                                              total_acc/total_count))\n",
    "            total_acc, total_count = 0, 0\n",
    "            start_time = time.time()\n",
    "\n",
    "def evaluate(dataloader):\n",
    "    model.eval()\n",
    "    total_acc, total_count = 0, 0\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for idx, (label, text, offsets) in enumerate(dataloader):\n",
    "            predicted_label = model(text, offsets)\n",
    "            loss = criterion(predicted_label, label)\n",
    "            total_acc += (predicted_label.argmax(1) == label).sum().item()\n",
    "            total_count += label.size(0)\n",
    "    return total_acc/total_count"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "source": [
    "# 启动训练\n",
    "for epoch in range(1, EPOCHS + 1):\n",
    "    epoch_start_time = time.time()\n",
    "    train(train_dataloader)\n",
    "    accu_val = evaluate(valid_dataloader)\n",
    "    # if total_accu is not None and total_accu > accu_val:\n",
    "    #   scheduler.step()\n",
    "    # else:\n",
    "    total_accu = accu_val\n",
    "    print('-' * 59)\n",
    "    print('| end of epoch {:3d} | time: {:5.2f}s | '\n",
    "          'valid accuracy {:8.3f} '.format(epoch,time.time() - epoch_start_time,accu_val))\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "execution_count": 27,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "shape of embedded torch.Size([18966, 64])\n"
     ]
    },
    {
     "ename": "RuntimeError",
     "evalue": "number of dims don't match in permute",
     "output_type": "error",
     "traceback": [
      "\u001B[1;31m---------------------------------------------------------------------------\u001B[0m",
      "\u001B[1;31mRuntimeError\u001B[0m                              Traceback (most recent call last)",
      "\u001B[1;32m~\\AppData\\Local\\Temp/ipykernel_11332/1879018591.py\u001B[0m in \u001B[0;36m<module>\u001B[1;34m\u001B[0m\n\u001B[0;32m      2\u001B[0m \u001B[1;32mfor\u001B[0m \u001B[0mepoch\u001B[0m \u001B[1;32min\u001B[0m \u001B[0mrange\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;36m1\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mEPOCHS\u001B[0m \u001B[1;33m+\u001B[0m \u001B[1;36m1\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m      3\u001B[0m     \u001B[0mepoch_start_time\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mtime\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mtime\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m----> 4\u001B[1;33m     \u001B[0mtrain\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mtrain_dataloader\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m      5\u001B[0m     \u001B[0maccu_val\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mevaluate\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mvalid_dataloader\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m      6\u001B[0m     \u001B[1;31m# if total_accu is not None and total_accu > accu_val:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
      "\u001B[1;32m~\\AppData\\Local\\Temp/ipykernel_11332/1756616691.py\u001B[0m in \u001B[0;36mtrain\u001B[1;34m(dataloader)\u001B[0m\n\u001B[0;32m      9\u001B[0m     \u001B[1;32mfor\u001B[0m \u001B[0midx\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;33m(\u001B[0m\u001B[0mlabel\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mtext\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0moffsets\u001B[0m\u001B[1;33m)\u001B[0m \u001B[1;32min\u001B[0m \u001B[0menumerate\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mdataloader\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m     10\u001B[0m         \u001B[0moptimizer\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mzero_grad\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m---> 11\u001B[1;33m         \u001B[0mpredicted_label\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mmodel\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mtext\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m     12\u001B[0m         \u001B[0mloss\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mcriterion\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mpredicted_label\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mlabel\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m     13\u001B[0m         \u001B[0mloss\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mbackward\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
      "\u001B[1;32mc:\\users\\dengxixi\\miniconda3\\envs\\torch-ocr\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001B[0m in \u001B[0;36m_call_impl\u001B[1;34m(self, *input, **kwargs)\u001B[0m\n\u001B[0;32m   1049\u001B[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001B[0;32m   1050\u001B[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001B[1;32m-> 1051\u001B[1;33m             \u001B[1;32mreturn\u001B[0m \u001B[0mforward_call\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;33m*\u001B[0m\u001B[0minput\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;33m**\u001B[0m\u001B[0mkwargs\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m   1052\u001B[0m         \u001B[1;31m# Do not call functions when jit is used\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m   1053\u001B[0m         \u001B[0mfull_backward_hooks\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mnon_full_backward_hooks\u001B[0m \u001B[1;33m=\u001B[0m \u001B[1;33m[\u001B[0m\u001B[1;33m]\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;33m[\u001B[0m\u001B[1;33m]\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
      "\u001B[1;32m~\\AppData\\Local\\Temp/ipykernel_11332/2206131126.py\u001B[0m in \u001B[0;36mforward\u001B[1;34m(self, text)\u001B[0m\n\u001B[0;32m     12\u001B[0m         \u001B[0membedded\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0membedding\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mtext\u001B[0m\u001B[1;33m)\u001B[0m \u001B[1;31m# [sent len, batch size, emb dim]\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m     13\u001B[0m         \u001B[0mprint\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;34m'shape of embedded'\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0membedded\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mshape\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m---> 14\u001B[1;33m         \u001B[0membedded\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0membedded\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mpermute\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;36m1\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;36m0\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;36m2\u001B[0m\u001B[1;33m)\u001B[0m \u001B[1;31m# [batch size, sent len, emb dim]\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m     15\u001B[0m         \u001B[1;31m# 逐列求平均\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m     16\u001B[0m         \u001B[0mpooled\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mF\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mavg_pool2d\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0membedded\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;33m(\u001B[0m\u001B[0membedded\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mshape\u001B[0m\u001B[1;33m[\u001B[0m\u001B[1;36m1\u001B[0m\u001B[1;33m]\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;36m1\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0msqueeze\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;36m1\u001B[0m\u001B[1;33m)\u001B[0m \u001B[1;31m# [batch size, embedding_dim]\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
      "\u001B[1;31mRuntimeError\u001B[0m: number of dims don't match in permute"
     ]
    }
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# 启动测试\n",
    "print('Checking the results of test dataset.')\n",
    "accu_test = evaluate(test_dataloader)\n",
    "print('test accuracy {:8.3f}'.format(accu_test))"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## RNN模型\n",
    "\n",
    "- 下面我们尝试把模型换成一个**recurrent neural network** (RNN)。\n",
    "RNN经常会被用来encode一个sequence\n",
    "\n",
    "     $$h_t = \\text{RNN}(x_t, h_{t-1})$$\n",
    "\n",
    "- 我们使用最后一个hidden state $h_T$来表示整个句子。\n",
    "\n",
    "- 然后我们把$h_T$通过一个线性变换$f$，然后用来预测句子的情感。\n",
    "\n",
    "![](assets/sentiment1.png)\n",
    "\n",
    "![](assets/sentiment7.png)\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "outputs": [],
   "source": [
    "from torch import nn\n",
    "import torch.nn.functional as F\n",
    "#模型构建\n",
    "class RNNModel(nn.Module):\n",
    "    def __init__(self, vocab_size, embedding_dim, output_dim, hidden_size, dropout):\n",
    "        super(RNNModel,self).__init__()\n",
    "        # TODO 此处有问题\n",
    "        self.embedding = nn.Embedding(vocab_size, embedding_dim) # 单词变成词向量\n",
    "        self.lstm = nn.LSTM(embedding_dim, hidden_size)\n",
    "        self.linear = nn.Linear(hidden_size, output_dim)\n",
    "        self.dropout = nn.Dropout(dropout) # 每个位置\n",
    "\n",
    "    def forward(self, text):\n",
    "        embedded = self.embedding(text) # [sent len, batch size, emb dim]\n",
    "\n",
    "        embedded = self.dropout(embedded)\n",
    "        print('shape of embedded', embedded.shape) # TODO 此处有问题\n",
    "        output, (hidden, cell) = self.lstm(embedded) # input must have 3 dimensions, got 2\n",
    "        # hidden: 1 * batch_sze * hidden_size\n",
    "        hidden = self.dropout(hidden.squeenze())\n",
    "        return self.linear(hidden)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## CNN模型"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "class CNN(nn.Module):\n",
    "    def __init__(self, vocab_size, embedding_dim, n_filters,\n",
    "                 filter_sizes, output_dim, dropout, pad_idx):\n",
    "        super().__init__()\n",
    "\n",
    "        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_idx)\n",
    "        self.convs = nn.ModuleList([\n",
    "                                    nn.Conv2d(in_channels = 1, out_channels = n_filters,\n",
    "                                              kernel_size = (fs, embedding_dim))\n",
    "                                    for fs in filter_sizes\n",
    "                                    ])\n",
    "        self.fc = nn.Linear(len(filter_sizes) * n_filters, output_dim)\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "\n",
    "    def forward(self, text):\n",
    "        text = text.permute(1, 0) # [batch size, sent len]\n",
    "        embedded = self.embedding(text) # [batch size, sent len, emb dim]\n",
    "        embedded = embedded.unsqueeze(1) # [batch size, 1, sent len, emb dim]\n",
    "        conved = [F.relu(conv(embedded)).squeeze(3) for conv in self.convs]\n",
    "\n",
    "        #conv_n = [batch size, n_filters, sent len - filter_sizes[n]]\n",
    "\n",
    "        pooled = [F.max_pool1d(conv, conv.shape[2]).squeeze(2) for conv in conved]\n",
    "\n",
    "        #pooled_n = [batch size, n_filters]\n",
    "\n",
    "        cat = self.dropout(torch.cat(pooled, dim=1))\n",
    "\n",
    "        #cat = [batch size, n_filters * len(filter_sizes)]\n",
    "\n",
    "        return self.fc(cat)\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "name": "python3",
   "language": "python",
   "display_name": "Python 3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}